from typing import Any, Optional, Union

import numpy as np
import torch
from gymnasium.spaces import Discrete
from tianshou.data import Batch
from tianshou.policy import PPOPolicy

from Policy.fpg import FPGPolicy


class PPOPolicyMaskEnabled(PPOPolicy):
    """PPO policy whose forward pass can mask out unavailable discrete actions.

    The mask is read from ``batch.obs.mask``. When the observation carries no
    ``mask`` entry, the actor output is used unmodified.
    """

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> Batch:
        """Compute the action distribution and pick an action, applying a mask if present.

        Masking is only applicable to a ``Discrete`` action space. To enable
        it, put the mask in the observation as ``batch.obs.mask``: entries that
        evaluate to true keep their logit, all other entries are set to
        ``-inf`` before the distribution is built.

        NOTE(review): filling with ``-inf`` assumes the actor returns
        unnormalized logits (not probabilities) and that every row of the mask
        allows at least one action -- confirm against the actor / ``dist_fn``
        in use.

        :param batch: input batch; ``batch.obs`` and ``batch.info`` are passed
            to the actor, and ``batch.obs.mask`` is used when available.
        :param state: optional hidden state forwarded to the actor.
        :param kwargs: accepted for interface compatibility; not used here.
        :return: a ``Batch`` with the keys ``logits`` (masked if a mask was
            given), ``act``, ``state`` (hidden state returned by the actor)
            and ``dist``.
        :raises ValueError: if a mask is supplied but the action space is not
            ``Discrete``, or if deterministic evaluation is requested for an
            ``action_type`` other than ``"discrete"`` / ``"continuous"``.
        """
        logits, hidden = self.actor(batch.obs, state=state, info=batch.info)

        # changes here: mask out actions that are not available
        if isinstance(batch.obs, Batch) and "mask" in batch.obs:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if not isinstance(self.action_space, Discrete):
                raise ValueError(
                    "action masking is only supported for Discrete action spaces, "
                    f"got {type(self.action_space).__name__}"
                )
            # Accept NumPy arrays as well as tensors, and make sure the mask
            # lives on the same device as the logits before masked_fill.
            mask = torch.as_tensor(
                batch.obs.mask, dtype=torch.bool, device=logits.device
            )
            logits = logits.masked_fill(~mask, float("-inf"))

        if isinstance(logits, tuple):
            dist = self.dist_fn(*logits)
        else:
            dist = self.dist_fn(logits)

        if self._deterministic_eval and not self.training:
            if self.action_type == "discrete":
                # Greedy action; masked entries are -inf and cannot win.
                act = logits.argmax(-1)
            elif self.action_type == "continuous":
                # First element of the actor output -- presumably the
                # distribution mean; confirm against the actor in use.
                act = logits[0]
            else:
                # Previously `act` was left unbound here and the return
                # statement failed with an UnboundLocalError.
                raise ValueError(
                    "deterministic evaluation is not supported for "
                    f"action_type {self.action_type!r}"
                )
        else:
            act = dist.sample()
        return Batch(logits=logits, act=act, state=hidden, dist=dist)


class FPGPolicyMaskEnabled(FPGPolicy):
    """FPG policy whose forward pass can mask out unavailable discrete actions.

    The mask is read from ``batch.obs.mask``. When the observation carries no
    ``mask`` entry, the actor output is used unmodified.
    """

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> Batch:
        """Compute the action distribution and pick an action, applying a mask if present.

        Masking is only applicable to a ``Discrete`` action space. To enable
        it, put the mask in the observation as ``batch.obs.mask``: entries that
        evaluate to true keep their logit, all other entries are set to
        ``-inf`` before the distribution is built.

        NOTE(review): filling with ``-inf`` assumes the actor returns
        unnormalized logits (not probabilities) and that every row of the mask
        allows at least one action -- confirm against the actor / ``dist_fn``
        in use.

        :param batch: input batch; ``batch.obs`` and ``batch.info`` are passed
            to the actor, and ``batch.obs.mask`` is used when available.
        :param state: optional hidden state forwarded to the actor.
        :param kwargs: accepted for interface compatibility; not used here.
        :return: a ``Batch`` with the keys ``logits`` (masked if a mask was
            given), ``act``, ``state`` (hidden state returned by the actor)
            and ``dist``.
        :raises ValueError: if a mask is supplied but the action space is not
            ``Discrete``, or if deterministic evaluation is requested for an
            ``action_type`` other than ``"discrete"`` / ``"continuous"``.
        """
        logits, hidden = self.actor(batch.obs, state=state, info=batch.info)

        # changes here: mask out actions that are not available
        if isinstance(batch.obs, Batch) and "mask" in batch.obs:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if not isinstance(self.action_space, Discrete):
                raise ValueError(
                    "action masking is only supported for Discrete action spaces, "
                    f"got {type(self.action_space).__name__}"
                )
            # Accept NumPy arrays as well as tensors, and make sure the mask
            # lives on the same device as the logits before masked_fill.
            mask = torch.as_tensor(
                batch.obs.mask, dtype=torch.bool, device=logits.device
            )
            logits = logits.masked_fill(~mask, float("-inf"))

        if isinstance(logits, tuple):
            dist = self.dist_fn(*logits)
        else:
            dist = self.dist_fn(logits)

        if self._deterministic_eval and not self.training:
            if self.action_type == "discrete":
                # Greedy action; masked entries are -inf and cannot win.
                act = logits.argmax(-1)
            elif self.action_type == "continuous":
                # First element of the actor output -- presumably the
                # distribution mean; confirm against the actor in use.
                act = logits[0]
            else:
                # Previously `act` was left unbound here and the return
                # statement failed with an UnboundLocalError.
                raise ValueError(
                    "deterministic evaluation is not supported for "
                    f"action_type {self.action_type!r}"
                )
        else:
            act = dist.sample()
        return Batch(logits=logits, act=act, state=hidden, dist=dist)
